{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross Entropy \n",
    "Let's consider the example described in Section 6.3 where we are trying to predict whether a given image contains a cat, dog, airplane or automobile. This is a 4 class classification problem. We know from the ground truth that the image is that of a cat. Thus $X_{gt} = [1, 0, 0, 0]$, where the 4 numbers represents the probability of the 4 classes respectively. Since this is the ground truth data, we are confident that the image is that of a cat. \n",
    "\n",
    "Let's say we have a model that predicts the probability of each of these classes. Let's call this prediction $X_{pred}$. A good prediction would have a high probability for cat and a low probability for the other classes. For example, $X_{good\\_pred} = [0.8, 0.15, 0.04, 0.01]$. On the other hand, a bad prediction would have a low probability for cat. For example, $X_{bad\\_pred} = [0.25, 0.25, 0.25, 0.25]$. \n",
    "\n",
    "Cross Entropy, denoted by $H_{C}$ gives us a way to quantify this measurement. It measures the dissimlarity between the two probability distributions. $H_{C}(X_{gt}, X_{good\\_pred})$ would be low, whereas $X_{gt}$ and $X_{bad\\_pred}$ would be high. It can be computed by using the following formula.\n",
    "\n",
    "$$ H_{C}(X_{gt}, X_{pred}) = -\\sum_{i=1}^{N} p_{gt}(i) \\log(p_{pred}(i)) $$\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def cross_entropy(X_gt, X_pred):\n",
    "    H_C = 0\n",
    "    for x_gt, x_pred in zip(X_gt, X_pred):\n",
    "        H_C += -1 * (x_gt * torch.log (x_pred))\n",
    "    return H_C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cross entropy for the good prediction: 0.223\n",
      "Cross entropy for the bad prediction: 1.386\n"
     ]
    }
   ],
   "source": [
    "X_gt = torch.Tensor([1., 0., 0., 0.])\n",
    "X_good_pred = torch.Tensor([0.8, 0.15, 0.04, 0.01])\n",
    "X_bad_pred = torch.Tensor([0.25, 0.25, 0.25, 0.25])\n",
    "\n",
    "H_C_good = cross_entropy(X_gt, X_good_pred)\n",
    "H_C_bad = cross_entropy(X_gt, X_bad_pred)\n",
    "\n",
    "# Note how the cross entropy for the bad prediction is much higher than that for the good prediction\n",
    "print('Cross entropy for the good prediction: {:.3f}'.format(H_C_good))\n",
    "print('Cross entropy for the bad prediction: {:.3f}'.format(H_C_bad))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
